import torch
from torch_geometric.nn import knn_graph, ClusterGCNConv, DMoNPooling
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj
from Autoencoder import Encoder, Decoder


class ETEClusterModel(torch.nn.Module):
    """Autoencoder + kNN graph + GCN + DMoN clustering, trained end to end.

    ``forward`` encodes the inputs, reconstructs them with the decoder, builds
    a k-nearest-neighbour graph over the encoded rows, refines the node
    features with one ``ClusterGCNConv`` layer and soft-clusters the nodes
    with ``DMoNPooling``.

    Args:
        dim: number of input features (``n_features`` of encoder/decoder).
        encoder_size: embedding size (``embedding_dim`` of the encoder,
            ``input_dim`` of the decoder); also the channel count of the GCN
            layer and of the pooling layer.
        num_neighbors: ``k`` used when building the kNN graph.
        num_clusters: number of clusters ``k`` of the DMoN pooling layer.
        seq_len: sequence length forwarded to the encoder and decoder.
    """

    def __init__(self, dim, encoder_size, num_neighbors, num_clusters, seq_len):
        super().__init__()

        self.encoder = Encoder(
            seq_len=seq_len, n_features=dim, embedding_dim=encoder_size
        )
        self.decoder = Decoder(seq_len=seq_len, n_features=dim, input_dim=encoder_size)

        self.pool = DMoNPooling(channels=encoder_size, k=num_clusters)

        self.GCN = ClusterGCNConv(encoder_size, encoder_size)

        self.num_neighbors = num_neighbors

    @staticmethod
    def _as_node_matrix(x):
        """Flatten ``x`` to a 2-D ``(num_nodes, num_features)`` matrix.

        The last axis is kept as the feature axis and every leading axis is
        folded into the node axis. Unlike ``squeeze()``, this never drops the
        node axis when there is a single row (batch of one) or the feature
        axis when there is a single feature.
        """
        return x.reshape(-1, x.size(-1))

    def encode(self, x):
        """Return the encoder output for ``x``."""
        return self.encoder(x)

    def decode(self, x):
        """Return the decoder output for the embedding ``x``."""
        return self.decoder(x)

    def build_graph(self, x, batch=None):
        """Build a kNN graph (no self loops) over the rows of ``x``.

        Args:
            x: embeddings; flattened with ``_as_node_matrix`` so each row of
                the flattened matrix becomes one node.
            batch: optional PyG batch vector forwarded to ``knn_graph``.

        Returns:
            ``Data`` whose ``x`` is the *unmodified* input and whose
            ``edge_index`` holds the kNN edges. Because ``loop=False``, a
            single-row input yields a graph with no edges.
        """
        edge_index = knn_graph(
            self._as_node_matrix(x), k=self.num_neighbors, batch=batch, loop=False
        )

        return Data(x, edge_index)

    def cluster(self, x):
        """Run DMoN pooling on graph ``x`` (a PyG ``Data`` object).

        The dense adjacency is sized to the number of node rows in ``x.x``.
        This keeps it aligned with the feature matrix even when some trailing
        nodes have no edges, e.g. a one-node graph.

        Returns:
            ``(s, spectral_loss, ortho_loss, cluster_loss)`` as produced by
            ``self.pool``; the pooled features and adjacency are discarded.
        """
        num_nodes = self._as_node_matrix(x.x).size(0)
        dense_adj = to_dense_adj(x.edge_index, max_num_nodes=num_nodes)

        s, _out, _out_adj, spectral_loss, ortho_loss, cluster_loss = self.pool(
            x.x, dense_adj
        )

        return s, spectral_loss, ortho_loss, cluster_loss

    def forward(self, inputs):
        """Encode, reconstruct, build the kNN graph and cluster it.

        Args:
            inputs: batch fed to the encoder; presumably shaped
                ``(batch, seq_len, dim)`` -- TODO confirm against
                ``Autoencoder.Encoder``.

        Returns:
            ``(enc, labels, spectral_loss, ortho_loss, cluster_loss, rec_loss)``
            where ``enc`` is the raw encoder output and ``labels`` is the first
            output of ``DMoNPooling``. Per the PyG docs, ``labels`` is presumably
            a soft cluster-assignment matrix rather than hard labels; confirm.
            The three DMoN auxiliary losses follow. ``rec_loss`` is the mean
            squared error between the decoder output and ``inputs``.
        """
        enc = self.encode(inputs)

        dec = self.decode(enc)

        # Same value as MSELoss()(inputs, dec); the functional form avoids
        # building a loss module on every call and uses (input, target) order.
        rec_loss = torch.nn.functional.mse_loss(dec, inputs)

        graph = self.build_graph(enc)

        # Flatten the embedding to (num_nodes, encoder_size) for the GCN.
        # squeeze() would drop the node axis for a batch of one.
        graph.x = self.GCN(self._as_node_matrix(graph.x), graph.edge_index)

        labels, spectral_loss, ortho_loss, cluster_loss = self.cluster(graph)

        return enc, labels, spectral_loss, ortho_loss, cluster_loss, rec_loss
